"""
Thresholder

Author: Jonathon Luiten

Simply reads in a set of detections, thresholds them at a given score threshold, and writes them out again.
"""

import os
import sys
from multiprocessing.pool import Pool
from multiprocessing import freeze_support

# Make the repository root importable when this script is run directly.
# This must happen before the trackeval imports below, otherwise it has no effect on them.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from trackeval.baselines import baseline_utils as butils  # noqa: E402
from trackeval.utils import get_code_path  # noqa: E402

# Detections scoring below this value are dropped. The value is also embedded
# in the output folder name as str(100 * THRESHOLD).
THRESHOLD = 0.2

code_path = get_code_path()

# Script settings. The "{split}" placeholder in both folder templates is
# filled in later with config["SPLIT"] via str.format.
config = dict(
    INPUT_FOL=os.path.join(
        code_path, "data/detections/rob_mots/{split}/non_overlap_supplied/data/"
    ),
    OUTPUT_FOL=os.path.join(
        code_path,
        f"data/detections/rob_mots/{{split}}/threshold_{100 * THRESHOLD}/data/",
    ),
    SPLIT="train",  # valid: 'train', 'val', 'test'.
    Benchmarks=None,  # If None, all benchmarks in SPLIT.
    Num_Parallel_Cores=None,  # If None, run without parallel.
    DETECTION_THRESHOLD=THRESHOLD,
)


def do_sequence(seq_file):
    """Threshold the detections of one sequence file and write them out again.

    Every timestep of every class is passed through ``butils.threshold`` with
    ``config["DETECTION_THRESHOLD"]``. The surviving detections are written
    with ``butils.write_seq`` to the path obtained by substituting the
    formatted ``config["OUTPUT_FOL"]`` for the formatted ``config["INPUT_FOL"]``.

    Args:
        seq_file: Path to a sequence file located inside the formatted
            ``config["INPUT_FOL"]`` directory.

    Raises:
        ValueError: If ``seq_file`` does not contain the formatted input
            folder. The folder substitution would then be a no-op, and the
            input file would be overwritten in place.
    """
    in_fol = config["INPUT_FOL"].format(split=config["SPLIT"])
    out_fol = config["OUTPUT_FOL"].format(split=config["SPLIT"])

    # Guard against silently overwriting the input detections.
    if in_fol not in seq_file:
        raise ValueError(
            f"{seq_file!r} is not inside the input folder {in_fol!r}; "
            "refusing to overwrite it."
        )

    # Load input data from file (e.g. provided detections)
    # data format: data['cls'][t] = {'ids', 'scores', 'im_hs', 'im_ws', 'mask_rles'}
    data = butils.load_seq(seq_file)
    threshold = config["DETECTION_THRESHOLD"]

    # Rows accumulated for writing out.
    # Output Format = [timestep ID class score im_h im_w mask_RLE]
    output_data = []
    for cls, cls_data in data.items():
        for timestep, t_data in enumerate(cls_data):
            t_data = butils.threshold(t_data, threshold)
            # Index the parallel lists by position, as the original did, so a
            # length mismatch still surfaces as an IndexError rather than
            # being silently truncated.
            for i, obj_id in enumerate(t_data["ids"]):
                output_data.append(
                    [
                        timestep,
                        int(obj_id),
                        cls,
                        t_data["scores"][i],
                        t_data["im_hs"][i],
                        t_data["im_ws"][i],
                        t_data["mask_rles"][i],
                    ]
                )

    # Write results to file
    out_file = seq_file.replace(in_fol, out_fol)
    butils.write_seq(output_data, out_file)

    # Use the argument, not a global: worker processes have no `seq_todo`.
    print("DONE:", seq_file)


if __name__ == "__main__":

    # Required to fix bug in multiprocessing on windows.
    freeze_support()

    # Obtain list of sequences to run the thresholder for.
    if config["Benchmarks"]:
        benchmarks = config["Benchmarks"]
    else:
        benchmarks = [
            "davis_unsupervised",
            "kitti_mots",
            "youtube_vis",
            "ovis",
            "bdd_mots",
            "tao",
        ]
        if config["SPLIT"] != "train":
            benchmarks += ["waymo", "mots_challenge"]

    input_fol = config["INPUT_FOL"].format(split=config["SPLIT"])
    seqs_todo = []
    for bench in benchmarks:
        bench_fol = os.path.join(input_fol, bench)
        # os.listdir returns entries in arbitrary order; sort so that runs
        # process (and log) sequences reproducibly.
        seqs_todo += [os.path.join(bench_fol, seq) for seq in sorted(os.listdir(bench_fol))]

    if config["Num_Parallel_Cores"]:
        # Run in parallel. do_sequence writes its own output and returns
        # nothing, so the map result is not kept.
        with Pool(config["Num_Parallel_Cores"]) as pool:
            pool.map(do_sequence, seqs_todo)
    else:
        # Run in series.
        for seq_todo in seqs_todo:
            do_sequence(seq_todo)
